{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tf.data ：数据集的构建与预处理\n",
    "\n",
    "TensorFlow 提供了 tf.data 这一模块，包括了一套灵活的数据集构建 API，能够帮助我们快速、高效地构建数据输入的流水线，尤其适用于数据量巨大的场景。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集对象的建立 \n",
    "\n",
    "\n",
    "tf.data 的核心是 tf.data.Dataset 类，提供了对数据集的高层封装。tf.data.Dataset 由一系列的可迭代访问的元素（element）组成，每个元素包含一个或多个张量。比如说，对于一个由图像组成的数据集，每个元素可以是一个形状为 长×宽×通道数 的图片张量，也可以是由图片张量和图片标签张量组成的元组（Tuple）。\n",
    "\n",
    "最基础的建立 tf.data.Dataset 的方法是使用 tf.data.Dataset.from_tensor_slices() ，适用于数据量较小（能够整个装进内存）的情况。具体而言，如果我们的数据集中的所有元素通过张量的第 0 维，拼接成一个大的张量（例如，前节的 MNIST 数据集的训练集即为一个 [60000, 28, 28, 1] 的张量，表示了 60000 张 28*28 的单通道灰度图像），那么我们提供一个这样的张量或者第 0 维大小相同的多个张量作为输入，即可按张量的第 0 维展开来构建数据集，数据集的元素数量为张量第 0 位的大小。具体示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "X = tf.constant([2013, 2014, 2015, 2016, 2017])\n",
    "Y = tf.constant([12000, 14000, 15000, 16500, 17500])\n",
    "\n",
    "# 也可以使用NumPy数组，效果相同\n",
    "# X = np.array([2013, 2014, 2015, 2016, 2017])\n",
    "# Y = np.array([12000, 14000, 15000, 16500, 17500])\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensor_slices((X, Y))\n",
    "\n",
    "for x, y in dataset:\n",
    "    print(x.numpy(), y.numpy())\n",
    "    \n",
    "'''\n",
    "当提供多个张量作为输入时，张量的第 0 维大小必须相同，且必须将多个张量作为元组（Tuple，即使用 Python 中的小括号）拼接并作为输入。\n",
    "\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "\n",
    "(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()\n",
    "train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]\n",
    "mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))\n",
    "\n",
    "for image, label in mnist_dataset:\n",
    "    plt.title(label.numpy())\n",
    "    plt.imshow(image.numpy()[:, :, 0])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 对于特别巨大而无法完整载入内存的数据集，我们可以先将数据集处理为 TFRecord 格式，然后使用 tf.data.TFRocordDataset() 进行载入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集对象的预处理\n",
    "tf.data.Dataset 类为我们提供了多种数据集预处理方法:\n",
    "    1. Dataset.map(f) ：对数据集中的每个元素应用函数 f ，得到一个新的数据集（这部分往往结合 tf.io 进行读写和解码文件， tf.image 进行图像处理）；\n",
    "    2. Dataset.shuffle(buffer_size) ：将数据集打乱（设定一个固定大小的缓冲区（Buffer），取出前 buffer_size 个元素放入，并从缓冲区中随机采样，采样后的数据用后续数据替换）；\n",
    "    3. Dataset.batch(batch_size) ：将数据集分成批次，即对每 batch_size 个元素，使用 tf.stack() 在第 0 维合并，成为一个元素；\n",
    "    4. Dataset.repeat() （重复数据集的元素）、\n",
    "    5. Dataset.reduce() （与 Map 相对的聚合操作）、\n",
    "    6. Dataset.take() （截取数据集中的前若干个元素）等\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset.map() 将所有图片旋转 90 度：\n",
    "def rot90(image, label):\n",
    "    image = tf.image.rot90(image)\n",
    "    return image, label\n",
    "\n",
    "mnist_dataset = mnist_dataset.map(rot90)\n",
    "\n",
    "for image, label in mnist_dataset:\n",
    "    plt.title(label.numpy())\n",
    "    plt.imshow(image.numpy()[:, :, 0])\n",
    "    plt.show()\n",
    "\n",
    "    \n",
    "    \n",
    "使用 Dataset.batch() 将数据集划分批次，每个批次的大小为 4：\n",
    "\n",
    "mnist_dataset = mnist_dataset.batch(4)\n",
    "\n",
    "for images, labels in mnist_dataset:    # image: [4, 28, 28, 1], labels: [4]\n",
    "    fig, axs = plt.subplots(1, 4)\n",
    "    for i in range(4):\n",
    "        axs[i].set_title(labels.numpy()[i])\n",
    "        axs[i].imshow(images.numpy()[i, :, :, 0])\n",
    "    plt.show()\n",
    "    \n",
    "使用 Dataset.shuffle() 将数据打散后再设置批次，缓存大小设置为 10000：\n",
    "mnist_dataset = mnist_dataset.shuffle(buffer_size=10000).batch(4)\n",
    "\n",
    "for images, labels in mnist_dataset:\n",
    "    fig, axs = plt.subplots(1, 4)\n",
    "    for i in range(4):\n",
    "        axs[i].set_title(labels.numpy()[i])\n",
    "        axs[i].imshow(images.numpy()[i, :, :, 0])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用 tf.data 的并行化策略提高训练流程效率 \n",
    "\n",
    " tf.data 的数据集对象为我们提供了 Dataset.prefetch() 方法，使得我们可以让数据集对象 Dataset 在训练时预取出若干个元素，使得在 GPU 训练的同时 CPU 可以准备数据，从而提升训练流程的效率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_dataset = mnist_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
    "#此处参数 buffer_size 既可手工设置，也可设置为 tf.data.experimental.AUTOTUNE 从而由 TensorFlow 自动选择合适的数值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 与此类似， Dataset.map() 也可以利用多 GPU 资源，并行化地对数据项进行变换，从而提高效率\n",
    "mnist_dataset = mnist_dataset.map(map_func=rot90, num_parallel_calls=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集元素的获取与使用\n",
    "Keras 支持使用 tf.data.Dataset 直接作为输入。当调用 tf.keras.Model 的 fit() 和 evaluate() 方法时，可以将参数中的输入数据 x 指定为一个元素格式为 (输入数据, 标签数据) 的 Dataset ，并忽略掉参数中的标签数据 y 。例如，对于上述的 MNIST 数据集，常规的 Keras 训练方式是："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit(x=train_data, y=train_label, epochs=num_epochs, batch_size=batch_size)\n",
    "#使用 tf.data.Dataset 后，我们可以直接传入 Dataset ：\n",
    "model.fit(mnist_dataset, epochs=num_epochs)   #由于已经通过 Dataset.batch() 方法划分了数据集的批次，所以这里也无需提供批次的大小。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 实例：cats_vs_dogs 图像分类\n",
    "以下代码以猫狗图片二分类任务为示例，展示了使用 tf.data 结合 tf.io 和 tf.image 建立 tf.data.Dataset 数据集，并进行训练和测试的完整过程。数据集可至 这里 下载。使用前须将数据集解压到代码中 data_dir 所设置的目录（此处默认设置为 C:/datasets/cats_vs_dogs ，可根据自己的需求进行修改）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "\n",
    "num_epochs = 10\n",
    "batch_size = 32\n",
    "learning_rate = 0.001\n",
    "data_dir = 'C:/datasets/cats_vs_dogs'\n",
    "train_cats_dir = data_dir + '/train/cats/'\n",
    "train_dogs_dir = data_dir + '/train/dogs/'\n",
    "test_cats_dir = data_dir + '/valid/cats/'\n",
    "test_dogs_dir = data_dir + '/valid/dogs/'\n",
    "\n",
    "def _decode_and_resize(filename, label):\n",
    "    image_string = tf.io.read_file(filename)            # 读取原始文件\n",
    "    image_decoded = tf.image.decode_jpeg(image_string)  # 解码JPEG图片\n",
    "    image_resized = tf.image.resize(image_decoded, [256, 256]) / 255.0\n",
    "    return image_resized, label\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # 构建训练数据集\n",
    "    train_cat_filenames = tf.constant([train_cats_dir + filename for filename in os.listdir(train_cats_dir)])\n",
    "    train_dog_filenames = tf.constant([train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)])\n",
    "    train_filenames = tf.concat([train_cat_filenames, train_dog_filenames], axis=-1)\n",
    "    train_labels = tf.concat([\n",
    "        tf.zeros(train_cat_filenames.shape, dtype=tf.int32), \n",
    "        tf.ones(train_dog_filenames.shape, dtype=tf.int32)], \n",
    "        axis=-1)\n",
    "\n",
    "    train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))\n",
    "    train_dataset = train_dataset.map(\n",
    "        map_func=_decode_and_resize, \n",
    "        num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
    "    # 取出前buffer_size个数据放入buffer，并从其中随机采样，采样后的数据用后续数据替换\n",
    "    train_dataset = train_dataset.shuffle(buffer_size=23000)    \n",
    "    train_dataset = train_dataset.batch(batch_size)\n",
    "    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    model = tf.keras.Sequential([\n",
    "        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(256, 256, 3)),\n",
    "        tf.keras.layers.MaxPooling2D(),\n",
    "        tf.keras.layers.Conv2D(32, 5, activation='relu'),\n",
    "        tf.keras.layers.MaxPooling2D(),\n",
    "        tf.keras.layers.Flatten(),\n",
    "        tf.keras.layers.Dense(64, activation='relu'),\n",
    "        tf.keras.layers.Dense(2, activation='softmax')\n",
    "    ])\n",
    "\n",
    "    model.compile(\n",
    "        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),\n",
    "        loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
    "        metrics=[tf.keras.metrics.sparse_categorical_accuracy]\n",
    "    )\n",
    "\n",
    "    model.fit(train_dataset, epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用以下代码进行测试：\n",
    " # 构建测试数据集\n",
    "    test_cat_filenames = tf.constant([test_cats_dir + filename for filename in os.listdir(test_cats_dir)])\n",
    "    test_dog_filenames = tf.constant([test_dogs_dir + filename for filename in os.listdir(test_dogs_dir)])\n",
    "    test_filenames = tf.concat([test_cat_filenames, test_dog_filenames], axis=-1)\n",
    "    test_labels = tf.concat([\n",
    "        tf.zeros(test_cat_filenames.shape, dtype=tf.int32), \n",
    "        tf.ones(test_dog_filenames.shape, dtype=tf.int32)], \n",
    "        axis=-1)\n",
    "\n",
    "    test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))\n",
    "    test_dataset = test_dataset.map(_decode_and_resize)\n",
    "    test_dataset = test_dataset.batch(batch_size)\n",
    "\n",
    "    print(model.metrics_names)\n",
    "    print(model.evaluate(test_dataset))\n",
    "    \n",
    "'''\n",
    "通过对以上示例进行性能测试，我们可以感受到 tf.data 的强大并行化性能。通过 prefetch() 的使用和在 map() \n",
    "过程中加入 num_parallel_calls 参数，模型训练的时间可缩减至原来的一半甚至更低。\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TFRecord ：TensorFlow 数据集存储格式\n",
    "\n",
    "TFRecord 是 TensorFlow 中的数据集存储格式。当我们将数据集整理成 TFRecord 格式后，TensorFlow 就可以高效地读取和处理这些数据集，从而帮助我们更高效地进行大规模的模型训练。\n",
    "\n",
    "\n",
    "TFRecord 可以理解为一系列序列化的 tf.train.Example 元素所组成的列表文件，而每一个 tf.train.Example 又由若干个 tf.train.Feature 的字典组成。形式如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset.tfrecords\n",
    "[\n",
    "    {   # example 1 (tf.train.Example)\n",
    "        'feature_1': tf.train.Feature,\n",
    "        ...\n",
    "        'feature_k': tf.train.Feature\n",
    "    },\n",
    "    ...\n",
    "    {   # example N (tf.train.Example)\n",
    "        'feature_1': tf.train.Feature,\n",
    "        ...\n",
    "        'feature_k': tf.train.Feature\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了将形式各样的数据集整理为 TFRecord 格式，我们可以对数据集中的每个元素进行以下步骤：\n",
    "    1.读取该数据元素到内存；\n",
    "\n",
    "    2.将该元素转换为 tf.train.Example 对象（每一个 tf.train.Example 由若干个 tf.train.Feature 的字典组成，因此需要先建立 Feature 的字典）；\n",
    "\n",
    "    3.将该 tf.train.Example 对象序列化为字符串，并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件。\n",
    "    \n",
    "而读取 TFRecord 数据则可按照以下步骤：\n",
    "\n",
    "    1.通过tf.data.TFRecordDataset 读入原始的 TFRecord 文件（此时文件中的 tf.train.Example 对象尚未被反序列化），获得一个 tf.data.Dataset 数据集对象；\n",
    "\n",
    "    2.通过 Dataset.map 方法，对该数据集对象中的每一个序列化的 tf.train.Example 字符串执行 tf.io.parse_single_example 函数，从而实现反序列化。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 将数据集存储为 TFRecord 文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''首先，下载数据集 并解压到 data_dir ，初始化数据集的图片文件名列表及标签。'''\n",
    "\n",
    "import tensorflow as tf\n",
    "import os\n",
    "\n",
    "data_dir = 'C:/datasets/cats_vs_dogs'\n",
    "train_cats_dir = data_dir + '/train/cats/'\n",
    "train_dogs_dir = data_dir + '/train/dogs/'\n",
    "tfrecord_file = data_dir + '/train/train.tfrecords'\n",
    "\n",
    "train_cat_filenames = [train_cats_dir + filename for filename in os.listdir(train_cats_dir)]\n",
    "train_dog_filenames = [train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)]\n",
    "train_filenames = train_cat_filenames + train_dog_filenames\n",
    "train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)  # 将 cat 类的标签设为0，dog 类的标签设为1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'然后，通过以下代码，迭代读取每张图片，建立 tf.train.Feature 字典和 tf.train.Example 对象，序列化并写入 TFRecord 文件。'\n",
    "\n",
    "with tf.io.TFRecordWriter(tfrecord_file) as writer:\n",
    "    for filename, label in zip(train_filenames, train_labels):\n",
    "        image = open(filename, 'rb').read()     # 读取数据集图片到内存，image 为一个 Byte 类型的字符串\n",
    "        feature = {                             # 建立 tf.train.Feature 字典\n",
    "            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 图片是一个 Bytes 对象\n",
    "            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))   # 标签是一个 Int 对象\n",
    "        }\n",
    "        example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通过字典建立 Example\n",
    "        writer.write(example.SerializeToString())   # 将Example序列化并写入 TFRecord 文件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "'''\n",
    "值得注意的是， tf.train.Feature 支持三种数据格式：\n",
    "\n",
    "tf.train.BytesList ：字符串或原始 Byte 文件（如图片），通过 bytes_list 参数传入一个由字符串数组初始化的 tf.train.BytesList 对象；\n",
    "\n",
    "tf.train.FloatList ：浮点数，通过 float_list 参数传入一个由浮点数数组初始化的 tf.train.FloatList 对象；\n",
    "\n",
    "tf.train.Int64List ：整数，通过 int64_list 参数传入一个由整数数组初始化的 tf.train.Int64List 对象。\n",
    "\n",
    "如果只希望保存一个元素而非数组，传入一个只有一个元素的数组即可。\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取 TFRecord 文件\n",
    "\n",
    "我们可以通过以下代码，读取之前建立的 train.tfrecords 文件，并通过 Dataset.map 方法，使用 tf.io.parse_single_example 函数对数据集中的每一个序列化的 tf.train.Example 对象解码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_dataset = tf.data.TFRecordDataset(tfrecord_file)    # 读取 TFRecord 文件\n",
    "\n",
    "feature_description = { # 定义Feature结构，告诉解码器每个Feature的类型是什么\n",
    "    'image': tf.io.FixedLenFeature([], tf.string),\n",
    "    'label': tf.io.FixedLenFeature([], tf.int64),\n",
    "}\n",
    "\n",
    "def _parse_example(example_string): # 将 TFRecord 文件中的每一个序列化的 tf.train.Example 解码\n",
    "    feature_dict = tf.io.parse_single_example(example_string, feature_description)\n",
    "    feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])    # 解码JPEG图片\n",
    "    return feature_dict['image'], feature_dict['label']\n",
    "\n",
    "dataset = raw_dataset.map(_parse_example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的 feature_description 类似于一个数据集的 “描述文件”，通过一个由键值对组成的字典，告知 tf.io.parse_single_example 函数每个 tf.train.Example 数据项有哪些 Feature，以及这些 Feature 的类型、形状等属性。 tf.io.FixedLenFeature 的三个输入参数 shape 、 dtype 和 default_value （可省略）为每个 Feature 的形状、类型和默认值。这里我们的数据项都是单个的数值或者字符串，所以 shape 为空数组。\n",
    "\n",
    "运行以上代码后，我们获得一个数据集对象 dataset ，这已经是一个可以用于训练的 tf.data.Dataset 对象了！我们从该数据集中读取元素并输出验证："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "\n",
    "for image, label in dataset:\n",
    "    plt.title('cat' if label == 0 else 'dog')\n",
    "    plt.imshow(image.numpy())\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 应用实例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNISTLoader():\n",
    "    def __init__(self):\n",
    "        mnist = tf.keras.datasets.mnist\n",
    "        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()\n",
    "        # MNIST中的图像默认为uint8（0-255的数字）。以下代码将其归一化到0-1之间的浮点数，并在最后增加一维作为颜色通道\n",
    "        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]\n",
    "        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]\n",
    "        self.train_label = self.train_label.astype(np.int32)    # [60000]\n",
    "        self.test_label = self.test_label.astype(np.int32)      # [10000]\n",
    "        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]\n",
    "\n",
    "    def get_batch(self, batch_size):\n",
    "        # 从数据集中随机取出batch_size个元素并返回\n",
    "        index = np.random.randint(0, self.num_train_data, batch_size)\n",
    "        return self.train_data[index, :], self.train_label[index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataLoader():\n",
    "    def __init__(self):\n",
    "        path = tf.keras.utils.get_file('nietzsche.txt',\n",
    "            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')\n",
    "        with open(path, encoding='utf-8') as f:\n",
    "            self.raw_text = f.read().lower()\n",
    "        self.chars = sorted(list(set(self.raw_text)))\n",
    "        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))\n",
    "        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))\n",
    "        self.text = [self.char_indices[c] for c in self.raw_text]\n",
    "\n",
    "    def get_batch(self, seq_length, batch_size):\n",
    "        seq = []\n",
    "        next_char = []\n",
    "        for i in range(batch_size):\n",
    "            index = np.random.randint(0, len(self.text) - seq_length)\n",
    "            seq.append(self.text[index:index+seq_length])\n",
    "            next_char.append(self.text[index+seq_length])\n",
    "        return np.array(seq), np.array(next_char)       # [batch_size, seq_length], [num_batch]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
